# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import dataclass, field
from typing import Optional, Union

from transformers import TrainingArguments
from transformers.trainer_utils import IntervalStrategy, SchedulerType

# Names exported when this module is star-imported.
__all__ = ["DefaultTrainingArguments"]


@dataclass
class DefaultTrainingArguments(TrainingArguments):
    """``TrainingArguments`` subclass with extra fields and its own defaults.

    Adds two optional fields, ``model_name_or_path`` and
    ``dataset_name_or_path``, and redeclares several fields
    (``output_dir``, ``do_train``, ``per_device_train_batch_size``,
    ``learning_rate``, ``save_strategy``, ``lr_scheduler_type`` and
    ``logging_steps``) with the defaults held in the ``default_*`` class
    attributes below.
    """

    # custom
    # Both fields default to ``None``, hence ``Optional[str]`` rather than
    # a bare ``str``.
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "model name or path."},
    )
    dataset_name_or_path: Optional[str] = field(
        default=None,
        metadata={"help": "dataset name or path."},
    )

    # huggingface
    # These attributes are deliberately left unannotated: ``@dataclass`` only
    # turns annotated names into fields, so they stay plain class attributes
    # that merely supply the defaults used by the fields declared below.
    default_output_dir = "./work_dirs"
    default_do_train = True
    default_per_device_train_batch_size = 1
    default_learning_rate = 2e-5
    default_save_strategy = "epoch"
    default_lr_scheduler_type = "cosine"
    default_logging_steps = 5

    output_dir: str = field(
        default=default_output_dir,
        metadata={
            "help": (
                "The output directory where the model predictions and "
                "checkpoints will be written."
            )
        },
    )
    do_train: bool = field(
        default=default_do_train, metadata={"help": "Whether to run training."}
    )
    per_device_train_batch_size: int = field(
        default=default_per_device_train_batch_size,
        metadata={"help": "Batch size per GPU/TPU core/CPU for training."},
    )
    learning_rate: float = field(
        default=default_learning_rate,
        metadata={"help": "The initial learning rate for AdamW."},
    )
    save_strategy: Union[IntervalStrategy, str] = field(
        default=default_save_strategy,
        metadata={"help": "The checkpoint save strategy to use."},
    )
    lr_scheduler_type: Union[SchedulerType, str] = field(
        default=default_lr_scheduler_type,
        metadata={"help": "The scheduler type to use."},
    )
    # Annotated as float because, per the help text, a value below 1 is
    # read as a ratio of total training steps; the default is a step count.
    logging_steps: float = field(
        default=default_logging_steps,
        metadata={
            "help": (
                "Log every X updates steps. Should be an integer or a "
                "float in range `[0,1)`. If smaller than 1, will be "
                "interpreted as ratio of total training steps."
            )
        },
    )
